{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predictors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from chemprop.nn.predictors import (\n",
    "    RegressionFFN,\n",
    "    BinaryClassificationFFN,\n",
    "    MulticlassClassificationFFN,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is an example of the output of [aggregation](./aggregation.ipynb), which serves as the input to the predictor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_datapoints_in_batch = 2\n",
    "hidden_dim = 300\n",
    "example_aggregation_output = torch.randn(n_datapoints_in_batch, hidden_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Feed forward network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The learned representation from message passing and aggregation is a fixed-length vector, just like a fixed representation. Other predictors, such as a random forest, could be used to make final predictions from this representation, but Chemprop implements a feed forward network (FFN) because it allows end-to-end training. The three basic Chemprop FFNs differ in the prediction task they are used for. Note that multiclass classification needs to know the number of classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "regression_ffn = RegressionFFN()\n",
    "binary_class_ffn = BinaryClassificationFFN()\n",
    "multi_class_ffn = MulticlassClassificationFFN(n_classes=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Input dimension"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The default input dimension of the predictor is the same as the default dimension of the message passing hidden representation. If your message passing hidden dimension is different, or if you have additional atom or datapoint descriptors, you need to change the predictor's input dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.2080],\n",
       "        [0.2787]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ffn = RegressionFFN()\n",
    "ffn(example_aggregation_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0877],\n",
       "        [-0.2629]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mp_hidden_dim = 2\n",
    "n_atom_descriptors = 1\n",
    "mp_output = torch.randn(n_datapoints_in_batch, mp_hidden_dim + n_atom_descriptors)\n",
    "example_datapoint_descriptors = torch.randn(n_datapoints_in_batch, 12)\n",
    "\n",
    "input_dim = mp_output.shape[1] + example_datapoint_descriptors.shape[1]\n",
    "\n",
    "ffn = RegressionFFN(input_dim=input_dim)\n",
    "ffn(torch.cat([mp_output, example_datapoint_descriptors], dim=1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Output dimension"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The number of tasks defaults to 1 but can be adjusted. Predictors that need to predict multiple values per task, like multiclass classification, will automatically adjust the output dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 4])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ffn = RegressionFFN(n_tasks=4)\n",
    "ffn(example_aggregation_output).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 4, 3])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ffn = MulticlassClassificationFFN(n_tasks=4, n_classes=3)\n",
    "ffn(example_aggregation_output).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Customization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following hyperparameters of the predictor are customizable:\n",
    "\n",
    " - the hidden dimension between layers, default: 300\n",
    " - the number of layers, default: 1\n",
    " - the dropout probability, default: 0.0 (i.e. no dropout)\n",
    " - which activation function, default: ReLU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0121],\n",
       "        [-0.0760]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "custom_ffn = RegressionFFN(hidden_dim=600, n_layers=3, dropout=0.1, activation=\"tanh\")\n",
    "custom_ffn(example_aggregation_output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Intermediate hidden representations can also be extracted. Note that each predictor layer consists of an activation layer, followed by dropout, followed by a linear layer. The first predictor layer only has the linear layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 600])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "layer = 2\n",
    "custom_ffn.encode(example_aggregation_output, i=layer).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RegressionFFN(\n",
       "  (ffn): MLP(\n",
       "    (0): Sequential(\n",
       "      (0): Linear(in_features=300, out_features=600, bias=True)\n",
       "    )\n",
       "    (1): Sequential(\n",
       "      (0): Tanh()\n",
       "      (1): Dropout(p=0.1, inplace=False)\n",
       "      (2): Linear(in_features=600, out_features=600, bias=True)\n",
       "    )\n",
       "    (2): Sequential(\n",
       "      (0): Tanh()\n",
       "      (1): Dropout(p=0.1, inplace=False)\n",
       "      (2): Linear(in_features=600, out_features=600, bias=True)\n",
       "    )\n",
       "    (3): Sequential(\n",
       "      (0): Tanh()\n",
       "      (1): Dropout(p=0.1, inplace=False)\n",
       "      (2): Linear(in_features=600, out_features=1, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (criterion): MSE(task_weights=[[1.0]])\n",
       "  (output_transform): Identity()\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "custom_ffn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Criterion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each predictor has a criterion that is used as the [loss function](../loss_functions.ipynb) during training. The default criterion for a predictor is defined in the predictor class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'chemprop.nn.metrics.MSE'>\n",
      "<class 'chemprop.nn.metrics.BCELoss'>\n",
      "<class 'chemprop.nn.metrics.CrossEntropyLoss'>\n"
     ]
    }
   ],
   "source": [
    "print(RegressionFFN._T_default_criterion)\n",
    "print(BinaryClassificationFFN._T_default_criterion)\n",
    "print(MulticlassClassificationFFN._T_default_criterion)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A custom criterion can be given to the predictor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn import MSE\n",
    "\n",
    "criterion = MSE(task_weights=torch.tensor([0.5, 1.0]))\n",
    "ffn = RegressionFFN(n_tasks=2, criterion=criterion)"
   ]
  },
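  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the effect of `task_weights` concrete, the cell below is a minimal sketch in plain PyTorch of a task-weighted mean squared error. It assumes that each task's squared error is multiplied by its weight before averaging; Chemprop's own `MSE` may differ in details such as the masking of missing targets and the exact reduction. The tensors `example_preds` and `example_targets` are made up for this illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical predictions and targets: 3 datapoints, 2 tasks\n",
    "example_preds = torch.tensor([[1.0, 2.0], [0.0, 1.0], [2.0, 0.0]])\n",
    "example_targets = torch.tensor([[0.0, 2.0], [0.0, 3.0], [1.0, 0.0]])\n",
    "example_task_weights = torch.tensor([0.5, 1.0])\n",
    "\n",
    "# Per-element squared error, shape (3, 2)\n",
    "squared_error = (example_preds - example_targets) ** 2\n",
    "\n",
    "# Broadcasting scales each task's column by its weight before averaging\n",
    "weighted_loss = (squared_error * example_task_weights).mean()\n",
    "unweighted_loss = squared_error.mean()\n",
    "weighted_loss, unweighted_loss"
   ]
  },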
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Regression vs. classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In addition to using different loss functions, regression and classification predictors also differ in how they transform the model outputs during inference.\n",
    "\n",
    "Regression should use a [scaler transform](../scaling.ipynb) if target normalization is used during training.\n",
    "\n",
    "Classification uses a sigmoid (for binary classification) or a softmax (for multiclass classification) transform to keep class probability predictions between 0 and 1."
   ]
  },
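  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustration of the regression case, the sketch below shows in plain PyTorch what undoing target normalization amounts to. The statistics `target_mean` and `target_std` and all variable names are assumptions made for this example, not Chemprop API; the linked scaling notebook describes how Chemprop handles this transform."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical training targets for a single task\n",
    "train_targets = torch.tensor([[10.0], [20.0], [30.0]])\n",
    "target_mean = train_targets.mean(dim=0)\n",
    "target_std = train_targets.std(dim=0)\n",
    "\n",
    "# During training, the model would be fit to the normalized targets\n",
    "scaled_targets = (train_targets - target_mean) / target_std\n",
    "\n",
    "# At inference, outputs in the scaled space are mapped back to the original units\n",
    "scaled_preds = scaled_targets.clone()  # stand-in for model outputs\n",
    "unscaled_preds = scaled_preds * target_std + target_mean\n",
    "torch.allclose(unscaled_preds, train_targets)"
   ]
  },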
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "probs = binary_class_ffn(example_aggregation_output)\n",
    "(0 < probs).all() and (probs < 1).all()"
   ]
  },
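  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The multiclass case can be illustrated in a similar way. The sketch below applies a softmax to made-up logits in plain PyTorch (the tensor `example_logits` is hypothetical and does not come from a Chemprop predictor) and checks that the resulting class probabilities lie between 0 and 1 and sum to 1 for each task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical raw outputs: 2 datapoints, 4 tasks, 3 classes\n",
    "example_logits = torch.randn(2, 4, 3)\n",
    "\n",
    "# Softmax over the last (class) dimension\n",
    "class_probs = example_logits.softmax(dim=-1)\n",
    "\n",
    "in_range = ((0 <= class_probs) & (class_probs <= 1)).all()\n",
    "sums_to_one = torch.allclose(class_probs.sum(dim=-1), torch.ones(2, 4))\n",
    "in_range, sums_to_one"
   ]
  },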
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Other predictors coming soon"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Beta versions of predictors for uncertainty and spectral tasks will be finalized in v2.1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn.predictors import (\n",
    "    MveFFN,\n",
    "    EvidentialFFN,\n",
    "    BinaryDirichletFFN,\n",
    "    MulticlassDirichletFFN,\n",
    "    SpectralFFN,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
